import torch
import matplotlib.pyplot as plt
from itertools import combinations


def load_data_n2(size):
    """Generate a synthetic two-class dataset.

    Each class is drawn from a 2-D Gaussian with unit standard deviation:
    class 0 is centred at (2, 2) and class 1 at (-2, -2).

    Args:
        size: Number of samples to draw per class.

    Returns:
        Tuple ``(X, y)``. ``X`` has shape ``(2 * size, 2)``; ``y`` has shape
        ``(2 * size, 1)`` and holds 0.0 for the first ``size`` rows and 1.0
        for the remaining rows.
    """
    shape = (size, 2)
    # Draw class 0 first, then class 1, so a fixed seed yields the same data.
    class0 = torch.normal(2, 1, shape)
    class1 = torch.normal(-2, 1, shape)
    features = torch.cat((class0, class1), dim=0)
    labels = torch.cat((torch.zeros((size, 1)), torch.ones((size, 1))), dim=0)
    return features, labels

def load_data_n3(size):
    """Generate a synthetic three-class dataset.

    Three 2-D Gaussian clusters are sampled:
      * class 0: mean (5, 5), std 1.3
      * class 1: mean (0, -2), std 1.2
      * class 2: mean (-4, 4), std 1.4

    Args:
        size: Number of samples to draw per class.

    Returns:
        Tuple ``(X, y)``. ``X`` has shape ``(3 * size, 2)``; ``y`` has shape
        ``(3 * size, 1)`` and holds 0.0, 1.0 and 2.0 for the three
        consecutive groups of ``size`` rows.
    """
    shape = (size, 2)
    # Keep the draw order (class 0, 1, 2) so a fixed seed yields the same data.
    cluster0 = torch.normal(5, 1.3, shape)
    # Classes 1 and 2 are sampled around the origin and then shifted.
    cluster1 = torch.normal(0, 1.2, shape) + torch.tensor([0.0, -2.0])
    cluster2 = torch.normal(0, 1.4, shape) + torch.tensor([-4.0, 4.0])
    features = torch.cat((cluster0, cluster1, cluster2), dim=0)
    labels = torch.cat(
        [torch.full((size, 1), float(k)) for k in range(3)], dim=0
    )
    return features, labels

def plot_loss_acc(epochs, loss, acc):
    """Plot the accuracy and loss curves side by side.

    Both curves are drawn in red on the figure named ``'acc & loss'``, using
    the first two cells of a 2x2 subplot grid (accuracy first, loss second).

    Args:
        epochs: X-axis values (the epochs at which metrics were recorded).
        loss: Loss value for each entry of ``epochs``.
        acc: Accuracy value for each entry of ``epochs``.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can chain e.g. ``.show()``.
    """
    plt.figure('acc & loss')
    panels = ((221, acc, 'acc'), (222, loss, 'loss'))
    for position, series, name in panels:
        plt.subplot(position)
        plt.plot(epochs, series, color='r', label=name)
        plt.xlabel('epochs')
        plt.ylabel(name)
        plt.title(name)
    return plt

def plot_decision_boundary(model, X, y, X_test=None, y_test=None, h=0.01):
    """Draw a model's decision regions together with the data points.

    ``X`` (and ``X_test``) must carry a leading bias column: column 0 is
    ignored for plotting, columns 1 and 2 are used as the two plot axes. The
    grid fed to ``model.predict`` is built the same way (``[1, x, y]`` rows).

    Args:
        model: Object exposing ``predict(X)`` and ``accuracy(X, y)``.
        X: Training samples, shape ``(n, 3)`` with the bias column first.
        y: Training labels, used only to colour the training points.
        X_test: Optional test samples, laid out like ``X``.
        y_test: Optional test labels.
        h: Grid step used to sample the plane; smaller is finer but slower.

    Returns:
        The ``matplotlib.pyplot`` module, so the caller can chain e.g. ``.show()``.
    """
    # Test data is only usable when both samples and labels are supplied.
    has_test = X_test is not None and y_test is not None

    # Bounding box of the training data, padded by 0.5 on every side.
    x_min, x_max = float(X[:, 1].min()) - 0.5, float(X[:, 1].max()) + 0.5
    y_min, y_max = float(X[:, 2].min()) - 0.5, float(X[:, 2].max()) + 0.5
    xs, ys = torch.meshgrid(
        torch.arange(x_min, x_max, h),
        torch.arange(y_min, y_max, h),
        indexing='xy',
    )

    # Classify every grid point; rows are [1, x, y] to match the bias layout.
    grid = torch.stack([torch.ones_like(xs), xs, ys], dim=-1).reshape(-1, 3)
    # The grid can hold millions of points, so skip autograd bookkeeping.
    with torch.no_grad():
        zs = model.predict(grid).reshape(xs.shape)

    plt.figure('decision boundary')
    if has_test:
        acc = model.accuracy(X_test, y_test)
        plt.title(f'Accuracy:{float(acc)*100:.2f}%')
    # Without test data no accuracy is known, so no accuracy title is drawn.
    plt.contourf(xs, ys, zs, alpha=0.2)
    plt.scatter(X[:, 1], X[:, 2], c=y.numpy(), s=50, lw=0, cmap='RdYlGn', alpha=0.5)
    if has_test:
        plt.scatter(X_test[:, 1], X_test[:, 2], c=y_test.numpy(), s=20, lw=0, alpha=1.0)
    return plt


class LR:
    """Binary logistic regression trained with full-batch gradient descent.

    The model has no separate bias term: ``predict`` computes
    ``sigmoid(X @ W)``, so callers that want an intercept must add a
    constant-1 column to ``X`` themselves.
    """

    def __init__(self):
        self.W = None  # weights, shape (num_features, 1); created by fit()
        self.lr = 0.1  # learning rate used by gradient_descent()
        self.num_epochs = 1
        # Training history, one entry per logged epoch (every 10th epoch).
        self.losses = []
        self.acc = []
        self.epochs = []

    def sigmoid(self, X):
        """Return the element-wise logistic sigmoid of ``X``."""
        return torch.sigmoid(X)

    def gradient_descent(self, X, y):
        """Apply one gradient-descent step in place and return the weights.

        The gradient is the *sum* (not the mean) over all samples of the
        cross-entropy gradient, so the effective step grows with the number
        of rows in ``X``.
        """
        y_pred = self.sigmoid(X @ self.W)
        gradient = -torch.sum((y - y_pred) * X, dim=0).reshape((-1, 1))
        self.W -= self.lr * gradient
        return self.W

    def loss(self, X, y):
        """Return the summed binary cross-entropy of the model on ``(X, y)``."""
        y_pred = self.sigmoid(X @ self.W)
        # The 1e-10 offset keeps log() finite when a probability saturates.
        log1 = torch.log(y_pred + 1e-10)
        log2 = torch.log(1 - y_pred + 1e-10)
        return -torch.sum(y * log1 + (1 - y) * log2)

    def accuracy(self, X, y):
        """Return the fraction of rows of ``X`` whose prediction equals ``y``."""
        return torch.sum(self.predict(X) == y) / len(y)

    def fit(self, X, y, num_epochs=100, lr=0.1, make_cache=True, echo=True):
        """Train the model on ``(X, y)`` from freshly sampled weights.

        Args:
            X: Samples, shape ``(n, num_features)``.
            y: Binary targets (0.0 / 1.0), shape ``(n, 1)``.
            num_epochs: Number of gradient-descent steps.
            lr: Learning rate.
            make_cache: When true, record loss/accuracy every 10th epoch in
                ``self.losses`` / ``self.acc`` / ``self.epochs``.
            echo: When true, print the metrics every 10th epoch.

        Returns:
            ``self``, to allow call chaining.
        """
        self.lr = lr
        self.num_epochs = num_epochs
        self.W = torch.normal(0.0, 1.0, (X.shape[1], 1))
        # Weights restart from scratch, so the recorded history must too.
        self.losses, self.acc, self.epochs = [], [], []
        for epoch in range(self.num_epochs):
            self.gradient_descent(X, y)
            if epoch % 10 != 0:
                continue
            acc = self.accuracy(X, y)
            # Evaluate on the data passed to fit(), never on module globals.
            loss = self.loss(X, y)
            if make_cache:
                self.losses.append(loss)
                self.acc.append(acc)
                self.epochs.append(epoch)
            if echo:
                print(f'epoch:{epoch+1}, acc:{acc:.4f}, loss:{loss:.4f}')
        return self

    def predict(self, X):
        """Return hard 0.0 / 1.0 labels, shape ``(n, 1)``, thresholded at 0.5."""
        probs = self.sigmoid(X @ self.W)
        return (probs >= 0.5).to(probs.dtype)


class LRModel(torch.nn.Module):
    """Logistic-regression network: a bias-free linear layer plus a sigmoid."""

    def __init__(self, num_features) -> None:
        super().__init__()
        # The bias is handled outside the model (as an extra input column),
        # so the linear layer is created without its own bias parameter.
        linear = torch.nn.Linear(num_features, 1, bias=False)
        squash = torch.nn.Sigmoid()
        self.lr = torch.nn.Sequential(linear, squash)

    def forward(self, X):
        """Forward pass: return ``sigmoid(X @ W.T)`` with shape ``(n, 1)``."""
        probabilities = self.lr(X)
        return probabilities


class LRNet:
    """Logistic-regression trainer built on torch modules and optimizers."""

    def __init__(self, num_features=3):
        # Default of 3 = two raw features plus the constant-1 bias column
        # that callers prepend to every sample.
        self.net = LRModel(num_features)
        # Training history, one entry per logged epoch (every 10th epoch).
        self.epochs = []
        self.losses = []
        self.acc = []

    def predict(self, X):
        """Return hard 0.0 / 1.0 labels, shape ``(n, 1)``, thresholded at 0.5."""
        # Inference only: avoid building an autograd graph and avoid
        # overwriting the sigmoid output in place.
        with torch.no_grad():
            probs = self.net(X)
        return (probs >= 0.5).to(probs.dtype)

    def accuracy(self, X, y):
        """Return the fraction of rows of ``X`` whose prediction equals ``y``."""
        return (torch.sum(self.predict(X) == y) / len(y)).detach()

    def fit(self, X, y, num_epochs=10, lr=0.1):
        """Train the network with full-batch SGD on binary cross-entropy.

        Training continues from the network's current weights, and the
        history lists are appended to (not reset) on every call.

        Args:
            X: Samples, shape ``(n, num_features)``.
            y: Binary targets (0.0 / 1.0), shape ``(n, 1)``.
            num_epochs: Number of optimisation steps.
            lr: Learning rate passed to ``torch.optim.SGD``.

        Returns:
            ``self``, to allow call chaining.
        """
        criterion = torch.nn.BCELoss()  # loss function
        optimizer = torch.optim.SGD(self.net.parameters(), lr)
        for epoch in range(num_epochs):
            y_hat = self.net(X)  # forward pass
            loss = criterion(y_hat, y)
            optimizer.zero_grad()  # clear gradients left from the previous step
            loss.backward()  # back-propagate the loss gradient
            optimizer.step()  # update the parameters
            if epoch % 10 == 0:
                # The loss is the pre-step value; accuracy is measured after
                # the parameter update.
                self.epochs.append(epoch)
                self.losses.append(loss.detach().numpy())
                self.acc.append(self.accuracy(X, y).numpy())
        return self


class MultiLR:
    """Multi-class logistic regression using the one-vs-all (OVA) scheme.

    One binary classifier is trained per class ("this class" vs "the rest");
    prediction is decided by voting across all binary classifiers.
    """

    def __init__(self):
        self.models = []  # models[k] separates labels[k] from every other class
        self.labels = []  # sorted unique class labels seen by fit()

    def fit(self, X, y, lr=0.1, epochs=100, echo=False):
        """Train one binary ``LR`` model per distinct label in ``y``.

        Args:
            X: Samples, shape ``(n, num_features)``.
            y: Class labels, shape ``(n, 1)``.
            lr: Learning rate for every binary model.
            epochs: Number of training epochs for every binary model.
            echo: Forwarded to ``LR.fit`` to print training progress.

        Returns:
            ``self``, to allow call chaining.
        """
        self.labels = y.unique()
        # Start from an empty list so that refitting never leaves stale
        # models paired with the new labels.
        self.models = []
        for label in self.labels:
            # Binary target: 1 for the current class, 0 for all others.
            target = (y == label).to(y.dtype)
            net = LR()
            net.fit(X, target, num_epochs=epochs, lr=lr, echo=echo)
            self.models.append(net)
        return self

    def predict(self, X):
        """Return, for each row of ``X``, the position of the winning class.

        The result is an int64 tensor of shape ``(n,)`` indexing into
        ``self.labels``; for labels encoded as 0..K-1 the position equals
        the label itself. Ties are broken by ``torch.argmax``.
        """
        num_classes = len(self.labels)
        vote = torch.zeros((X.shape[0], num_classes))
        for k, model in zip(range(num_classes), self.models):
            is_k = model.predict(X).reshape(-1) == 1.0
            # "This class" is a single vote for class k ...
            vote[is_k, k] += 1
            # ... while "the rest" is one vote for every other class.
            rest = torch.ones(num_classes)
            rest[k] = 0.0
            vote[~is_k] += rest
        return torch.argmax(vote, dim=1)

    def accuracy(self, X, y):
        """Return the fraction of rows of ``X`` classified as their label in ``y``."""
        # Map winning positions back to label values before comparing.
        predicted = torch.as_tensor(self.labels)[self.predict(X)]
        return torch.sum(predicted == y.reshape(-1)) / len(y)


class MultiOVOLR:
    """Multi-class logistic regression using the one-vs-one (OVO) scheme.

    One binary classifier is trained for every unordered pair of classes;
    prediction is decided by voting across all pairwise classifiers.
    """

    def __init__(self):
        # One model per label pair, in ``itertools.combinations`` order.
        self.models = []
        self.labels = []  # sorted unique class labels seen by fit()

    def fit(self, X, y, lr=0.1, epochs=100, echo=False):
        """Train one binary ``LR`` model per pair of distinct labels in ``y``.

        Each model only sees the samples of its two classes; the first class
        of the pair is the positive (1.0) target, the second the negative.

        Args:
            X: Samples, shape ``(n, num_features)``.
            y: Class labels, shape ``(n, 1)``.
            lr: Learning rate for every binary model.
            epochs: Number of training epochs for every binary model.
            echo: Forwarded to ``LR.fit`` to print training progress.

        Returns:
            ``self``, to allow call chaining.
        """
        self.labels = y.unique()
        # Start from an empty list so that refitting never leaves stale
        # models paired with the new label pairs.
        self.models = []
        for first, second in combinations(self.labels, 2):
            first_rows = torch.where(y == first)[0]
            second_rows = torch.where(y == second)[0]
            rows = torch.cat([first_rows, second_rows], dim=0)
            subset = y[rows]
            target = (subset == first).to(y.dtype)
            net = LR()
            net.fit(X[rows, :], target, num_epochs=epochs, lr=lr, echo=echo)
            self.models.append(net)
        return self

    def predict(self, X):
        """Return, for each row of ``X``, the position of the winning class.

        The result is an int64 tensor of shape ``(n,)`` indexing into
        ``self.labels``; for labels encoded as 0..K-1 the position equals
        the label itself. Ties are broken by ``torch.argmax``.
        """
        num_classes = len(self.labels)
        vote = torch.zeros((X.shape[0], num_classes))
        # Pairs of positions are enumerated in the same order fit() used.
        pairs = combinations(range(num_classes), 2)
        for (a, b), model in zip(pairs, self.models):
            picks_a = model.predict(X).reshape(-1) == 1.0
            vote[picks_a, a] += 1
            vote[~picks_a, b] += 1
        return torch.argmax(vote, dim=1)

    def accuracy(self, X, y):
        """Return the fraction of rows of ``X`` classified as their label in ``y``."""
        # Map winning positions back to label values before comparing.
        predicted = torch.as_tensor(self.labels)[self.predict(X)]
        return torch.sum(predicted == y.reshape(-1)) / len(y)


if __name__ == '__main__':
    # Demo: train and visualise each model on synthetic data.

    # --- Binary classification -------------------------------------------
    num_samples = 1000
    X_train, y_train = load_data_n2(num_samples)
    X_test, y_test = load_data_n2(num_samples//3)
    # Prepend the bias column: X = [1 X]
    X_train = torch.cat([torch.ones((X_train.shape[0], 1)), X_train], dim=1)
    X_test = torch.cat([torch.ones((X_test.shape[0], 1)), X_test], dim=1)

    # Visualise the 2-D data (columns 1 and 2; column 0 is the bias).
    plt.figure()
    plt.scatter(X_train[:,1], X_train[:,2], c=y_train.numpy(), s=60, lw=0, cmap='RdYlGn', alpha=0.5)
    plt.scatter(X_test[:,1], X_test[:,2], c=y_test.numpy(), s=10, lw=1, alpha=1)

    # Hand-written logistic regression.
    model = LR()
    model.fit(X_train, y_train, num_epochs=1000, lr=0.001, make_cache=True, echo=False)

    # Visualise its results.
    acc = model.accuracy(X_test, y_test)
    print('LR model 测试集准确率:', acc.numpy())
    plot_loss_acc(model.epochs, model.losses, model.acc)
    plot_decision_boundary(model, X_train, y_train, X_test=X_test, y_test=y_test)
    plt.show()

    # torch-based implementation.
    lr = LRNet()
    lr.fit(X_train, y_train, num_epochs=100, lr=0.01)
    # Evaluate the torch model itself instead of reusing the previous
    # model's accuracy.
    acc = lr.accuracy(X_test, y_test)
    print('LR(torch版) model 测试集准确率:', acc.numpy())
    plot_loss_acc(lr.epochs, lr.losses, lr.acc)
    plot_decision_boundary(lr, X_train, y_train, X_test, y_test)
    plt.show()

    # --- Multi-class classification --------------------------------------
    X_train, y_train = load_data_n3(num_samples)
    X_test, y_test = load_data_n3(num_samples//3)
    # Prepend the bias column: X = [1 X]
    X_train = torch.cat([torch.ones((X_train.shape[0], 1)), X_train], dim=1)
    X_test = torch.cat([torch.ones((X_test.shape[0], 1)), X_test], dim=1)
    # Visualise the 2-D data.
    plt.figure()
    plt.scatter(X_train[:,1], X_train[:,2], c=y_train.numpy(), s=60, lw=0, cmap='RdYlGn', alpha=0.5)
    plt.scatter(X_test[:,1], X_test[:,2], c=y_test.numpy(), s=10, lw=1, alpha=1)

    # One-vs-all.
    mlr = MultiLR()
    mlr.fit(X_train, y_train, lr=0.05, epochs=100)
    acc = mlr.accuracy(X_test, y_test)
    print('OVA model 测试集准确率:', acc.numpy())
    plot_decision_boundary(mlr, X_train, y_train, X_test, y_test)
    plt.show()

    # One-vs-one.
    molr = MultiOVOLR()
    molr.fit(X_train, y_train, lr=0.05, epochs=100)
    acc = molr.accuracy(X_test, y_test)
    print('OVO model 测试集准确率:', acc.numpy())
    # Plot the OVO model that was just trained, not the OVA one.
    plot_decision_boundary(molr, X_train, y_train, X_test, y_test)
    plt.show()
